{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "colab": {
   "name": "Capsule Networks",
   "provenance": [],
   "collapsed_sections": []
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3"
  },
  "accelerator": "GPU"
 },
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "AYV_dMVDxyc2"
   },
   "source": [
    "[![Github](https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social)](https://github.com/labmlai/annotated_deep_learning_paper_implementations)\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/capsule_networks/mnist.ipynb)                    \n",
    "\n",
    "## Training a Capsule Network to classify MNIST digits\n",
    "\n",
    "This is an experiment to train a Capsule Network to classify MNIST digits using PyTorch."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "AahG_i2y5tY9"
   },
   "source": [
    "Install the `labml-nn` package"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "ZCzmCrAIVg0L",
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "outputId": "7ab15f72-c99f-4097-ecd2-5740ee9ed61c"
   },
   "source": [
    "!pip install labml-nn"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "SE2VUQ6L5zxI"
   },
   "source": [
    "Imports"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "0hJXx_g0wS2C"
   },
   "source": [
    "import torch\n",
    "\n",
    "from labml import experiment\n",
    "from labml_nn.capsule_networks.mnist import Configs"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Lpggo0wM6qb-"
   },
   "source": [
    "Create an experiment"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "bFcr9k-l4cAg"
   },
   "source": [
    "experiment.create(name=\"capsule_networks\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-OnHLi626tJt"
   },
   "source": [
    "Initialize the [Capsule Network configurations](https://nn.labml.ai/capsule_networks/mnist.html)"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "Piz0c5f44hRo"
   },
   "source": [
    "conf = Configs()"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "wwMzCqpD6vkL"
   },
   "source": [
    "Set the experiment configurations, passing a dictionary of values that override the defaults: the Adam optimizer with a learning rate of `1e-3`, and `inner_iterations` set to 5"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 17
    },
    "id": "e6hmQhTw4nks",
    "outputId": "ebefa8fa-93d2-4131-db95-e27f15aa3aa0"
   },
   "source": [
    "experiment.configs(conf, {'optimizer.optimizer': 'Adam',\n",
    "                         'optimizer.learning_rate': 1e-3,\n",
    "                         'inner_iterations': 5})"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "EvI7MtgJ61w5"
   },
   "source": [
    "Register the PyTorch model with the experiment so that its state can be saved and loaded"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 102
    },
    "id": "GDlt7dp-5ALt",
    "outputId": "9701092b-c88a-4687-c90e-b193c369e59e"
   },
   "source": [
    "experiment.add_pytorch_models({'model': conf.model})"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "KJZRf8527GxL"
   },
   "source": [
    "Start the experiment and run the training loop."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 646
    },
    "id": "aIAWo7Fw5DR8",
    "outputId": "5ddbfce3-91f8-4506-e483-1640cb5a14b3"
   },
   "source": [
    "with experiment.start():\n",
    "    conf.run()"
   ],
   "outputs": [],
   "execution_count": null
  }
 ]
}
